FEAT Expand memory interface and models for attack results #1420
romanlutz wants to merge 1 commit into Azure:main
romanlutz commented on Mar 1, 2026
- Add conversation_stats model and attack_result extensions
- Add get_attack_results with filtering by harm categories, labels, attack type, and converter types to memory interface
- Implement SQLite-specific JSON filtering for attack results
- Add memory_models field for targeted_harm_categories
- Add prompt_metadata support to openai image/video/response targets
- Fix missing return statements in SQLite harm_category and label filters
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Pull request overview
This PR expands PyRIT’s memory/DB layer to support richer attack-result querying and lightweight conversation summaries, and adds stricter “single-turn only” validation for OpenAI image/video targets.
Changes:
- Adds a ConversationStats model and new get_conversation_stats() memory API with SQLite/Azure SQL implementations.
- Extends AttackResult with a persisted row id (attack_result_id), adds update/dedup behavior around attack results, and expands filtering options (attack type / converter types / harm categories / labels).
- Updates OpenAI targets (base class adjustments + image/video single-turn validation) and adds unit tests for new behaviors/regressions.
Reviewed changes
Copilot reviewed 16 out of 16 changed files in this pull request and generated 7 comments.
Show a summary per file
| File | Description |
|---|---|
| tests/unit/target/test_video_target.py | Adds coverage for rejecting multi-turn usage in video target validation. |
| tests/unit/target/test_image_target.py | Adds coverage for rejecting multi-turn usage in image target. |
| tests/unit/memory/test_sqlite_memory.py | Adds tests for the new get_conversation_stats() aggregation behavior. |
| tests/unit/memory/memory_interface/test_interface_attack_results.py | Updates/extends attack-result interface tests (dedup + update regressions + renamed filters). |
| pyrit/prompt_target/openai/openai_video_target.py | Enforces single-turn conversations for video target requests. |
| pyrit/prompt_target/openai/openai_image_target.py | Enforces single-turn conversations for image target requests. |
| pyrit/prompt_target/openai/openai_target.py | Switches OpenAI base target to inherit from PromptTarget (not PromptChatTarget). |
| pyrit/prompt_target/openai/openai_realtime_target.py | Updates realtime target inheritance to retain chat-target semantics. |
| pyrit/prompt_target/openai/openai_response_target.py | Attempts to enrich identifier params for response target (currently introduces a runtime kwarg mismatch). |
| pyrit/models/conversation_stats.py | Introduces the ConversationStats dataclass used by memory stats API. |
| pyrit/models/attack_result.py | Adds attack_result_id field to represent DB-assigned row id. |
| pyrit/models/__init__.py | Re-exports ConversationStats. |
| pyrit/memory/sqlite_memory.py | Implements JSON filtering tweaks + adds get_conversation_stats() for SQLite. |
| pyrit/memory/azure_sql_memory.py | Adds get_conversation_stats() for Azure SQL and hardens _update_entries behavior. |
| pyrit/memory/memory_models.py | Maps DB primary key into AttackResult.attack_result_id when materializing domain objects. |
| pyrit/memory/memory_interface.py | Expands/renames attack result query APIs, adds get_conversation_stats(), and introduces attack-result update helpers + dedup. |

            "reasoning_effort": self._reasoning_effort,
            "reasoning_summary": self._reasoning_summary,
        },
        target_specific_params=specific_params,
    )

_build_identifier() passes target_specific_params=... into _create_identifier(), but PromptTarget._create_identifier() only accepts params and children. This will raise TypeError: _create_identifier() got an unexpected keyword argument 'target_specific_params' at runtime. Either merge these values into params, or extend _create_identifier() (and any overrides) to accept and correctly handle target_specific_params.
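
The first option, merging the extra values into params, could look roughly like the sketch below. `create_identifier` is a hypothetical stand-in that mirrors only the signature described in this comment (`params` and `children`); it is not PyRIT's implementation, and `build_identifier` and the example values are likewise assumptions.

```python
from typing import Any, Optional


def create_identifier(
    *,
    params: Optional[dict[str, Any]] = None,
    children: Optional[dict[str, Any]] = None,
) -> dict[str, Any]:
    """Hypothetical stand-in: accepts only `params` and `children`."""
    return {"params": dict(params or {}), "children": dict(children or {})}


def build_identifier(
    reasoning_effort: Optional[str],
    reasoning_summary: Optional[str],
    specific_params: Optional[dict[str, Any]] = None,
) -> dict[str, Any]:
    """Merge target-specific values into `params` instead of passing a new kwarg."""
    params: dict[str, Any] = {
        "reasoning_effort": reasoning_effort,
        "reasoning_summary": reasoning_summary,
    }
    # Target-specific entries are folded into the one mapping the callee accepts.
    params.update(specific_params or {})
    return create_identifier(params=params)


# Example values are illustrative only.
identifier = build_identifier("low", "auto", {"max_output_tokens": 256})
```

Passing `target_specific_params=` to the stand-in fails with `TypeError`, which is the failure mode described in the comment; the merged version avoids it without changing the callee.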

    from contextlib import closing

    with closing(self.get_session()) as session:
        from sqlalchemy.exc import SQLAlchemyError

add_attack_results_to_memory() introduces inline imports (contextlib.closing, SQLAlchemyError). This repo enables Ruff E402 (imports must be at top of module), and these imports aren’t for breaking a circular dependency. Move these imports to the module top (or reuse existing module-level imports) to avoid lint failures and keep imports consistent.

    Raises:
        ValueError: If update_fields is empty.

The docstring says update_attack_result() raises ValueError when update_fields is empty, but the implementation never checks for that case. Either add an explicit if not update_fields: raise ValueError(...) or update the docstring so it matches the behavior.

Suggested change:

    -Raises:
    -    ValueError: If update_fields is empty.

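
If the check is added rather than the docstring changed, the guard might look like this minimal sketch. The dictionary-backed `store` and the function signature are assumptions made so the example runs on its own; the real method works against the database.

```python
from typing import Any


def update_attack_result(
    store: dict[str, dict[str, Any]],
    conversation_id: str,
    update_fields: dict[str, Any],
) -> bool:
    """Apply update_fields to a stored row; return False if the row is missing.

    Raises:
        ValueError: If update_fields is empty.
    """
    # Fail fast so the documented contract and the behavior agree.
    if not update_fields:
        raise ValueError("update_fields must not be empty")
    row = store.get(conversation_id)
    if row is None:
        return False
    row.update(update_fields)
    return True
```
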

    entries: MutableSequence[AttackResultEntry] = self._query_entries(
        AttackResultEntry,
        conditions=AttackResultEntry.id == attack_result_id,

update_attack_result_by_id() compares AttackResultEntry.id (a UUID column) to a str (attack_result_id). Depending on the SQLAlchemy type, this can result in no matches even when the row exists. Convert attack_result_id to uuid.UUID (and handle invalid UUID strings) before building the condition.

Suggested change:

    -entries: MutableSequence[AttackResultEntry] = self._query_entries(
    -    AttackResultEntry,
    -    conditions=AttackResultEntry.id == attack_result_id,
    +try:
    +    attack_result_uuid = uuid.UUID(attack_result_id)
    +except (ValueError, TypeError):
    +    logger.warning(
    +        "Invalid attack_result_id '%s' passed to update_attack_result_by_id",
    +        attack_result_id,
    +    )
    +    return False
    +entries: MutableSequence[AttackResultEntry] = self._query_entries(
    +    AttackResultEntry,
    +    conditions=AttackResultEntry.id == attack_result_uuid,

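
The conversion step can also be isolated in a small helper. This is a sketch rather than code from the PR, and `parse_attack_result_id` is a hypothetical name. One detail worth noting: `uuid.UUID` raises `ValueError` for a malformed string and `TypeError` for `None`, while other non-string inputs such as integers fail with `AttributeError`, so the helper checks the type before converting.

```python
import uuid
from typing import Optional


def parse_attack_result_id(value: object) -> Optional[uuid.UUID]:
    """Return a UUID for a well-formed string, or None for anything else."""
    if isinstance(value, uuid.UUID):
        return value
    # Reject None, ints, and other non-strings before calling uuid.UUID().
    if not isinstance(value, str):
        return None
    try:
        return uuid.UUID(value)
    except ValueError:
        # Malformed hexadecimal UUID string.
        return None
```

A caller can then return False early when the helper yields None and build the query condition from the parsed UUID otherwise.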

     objective: Optional[str] = None,
     objective_sha256: Optional[Sequence[str]] = None,
     outcome: Optional[str] = None,
    -attack_class: Optional[str] = None,
    -converter_classes: Optional[Sequence[str]] = None,
    +attack_type: Optional[str] = None,
    +converter_types: Optional[Sequence[str]] = None,

This changes the public get_attack_results() filter API from attack_class/converter_classes to attack_type/converter_types. There are still call sites using the old keyword args (e.g., pyrit/backend/services/attack_service.py passes attack_class= and converter_classes=), which will raise TypeError at runtime. Either update those call sites in this PR, or keep backward-compatible aliases (deprecated) that map old names to the new ones.
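
The backward-compatible option could be sketched as follows. This is an illustration under stated assumptions, not the PR's code: the function is a standalone stand-in that returns the effective filters instead of querying memory, and the choice to let the new name win when both are given is an assumption.

```python
import warnings
from typing import Any, Optional, Sequence


def get_attack_results(
    *,
    attack_type: Optional[str] = None,
    converter_types: Optional[Sequence[str]] = None,
    attack_class: Optional[str] = None,  # deprecated alias for attack_type
    converter_classes: Optional[Sequence[str]] = None,  # deprecated alias
) -> dict[str, Any]:
    """Resolve old and new filter names and return the effective filters."""
    if attack_class is not None:
        warnings.warn(
            "attack_class is deprecated; use attack_type",
            DeprecationWarning,
            stacklevel=2,
        )
        if attack_type is None:
            attack_type = attack_class
    if converter_classes is not None:
        warnings.warn(
            "converter_classes is deprecated; use converter_types",
            DeprecationWarning,
            stacklevel=2,
        )
        if converter_types is None:
            converter_types = converter_classes
    return {
        "attack_type": attack_type,
        "converter_types": list(converter_types) if converter_types is not None else None,
    }
```

Existing callers that pass `attack_class=` or `converter_classes=` keep working and see a `DeprecationWarning`, which leaves room to migrate them in a follow-up change.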

     def get_unique_attack_type_names(self) -> list[str]:
         """
    -    Return sorted unique attack class names from all stored attack results.
    +    Return sorted unique attack type names from all stored attack results.

         Extracts class_name from the attack_identifier JSON column via a

Renaming get_unique_attack_class_names() / get_unique_converter_class_names() to *_type_names() is a breaking change for existing call sites (e.g., pyrit/backend/services/attack_service.py still calls the old method names). Please update those usages in this PR, or provide backward-compatible wrapper methods with deprecation warnings.
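
A wrapper of the kind suggested here might look like the sketch below. `MemoryStub` is a hypothetical stand-in that keeps names in a list so the example is self-contained; the real methods read from the database.

```python
import warnings


class MemoryStub:
    """Hypothetical stand-in for the memory interface."""

    def __init__(self, attack_type_names: list[str]) -> None:
        self._attack_type_names = attack_type_names

    def get_unique_attack_type_names(self) -> list[str]:
        """Return sorted unique attack type names."""
        return sorted(set(self._attack_type_names))

    def get_unique_attack_class_names(self) -> list[str]:
        """Deprecated wrapper kept so existing callers do not break."""
        warnings.warn(
            "get_unique_attack_class_names() is deprecated; "
            "use get_unique_attack_type_names()",
            DeprecationWarning,
            stacklevel=2,
        )
        return self.get_unique_attack_type_names()
```

The same pattern would apply to `get_unique_converter_class_names()`.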

    request = message.message_pieces[0]
    messages = self._memory.get_conversation(conversation_id=request.conversation_id)

_validate_request() pulls the conversation_id from message.message_pieces[0], but this method already identified the single text piece (text_piece) above. Using text_piece.conversation_id avoids relying on message-piece ordering and keeps the validation consistent with the rest of the method.

Suggested change:

    -request = message.message_pieces[0]
    -messages = self._memory.get_conversation(conversation_id=request.conversation_id)
    +messages = self._memory.get_conversation(conversation_id=text_piece.conversation_id)
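
To show why the lookup should go through the identified text piece, here is a simplified sketch of a single-turn check. The `Piece` dataclass, the `history` mapping, and the error messages are all assumptions for illustration and do not reflect PyRIT's actual types.

```python
from dataclasses import dataclass


@dataclass
class Piece:
    """Hypothetical, simplified message piece."""

    conversation_id: str
    data_type: str


def validate_single_turn(pieces: list[Piece], history: dict[str, list[str]]) -> None:
    """Raise ValueError unless there is one text piece and no prior turns."""
    text_pieces = [p for p in pieces if p.data_type == "text"]
    if len(text_pieces) != 1:
        raise ValueError("expected exactly one text piece")
    text_piece = text_pieces[0]
    # Look up history via the identified text piece, not pieces[0],
    # so the check does not depend on the order of the pieces.
    if history.get(text_piece.conversation_id):
        raise ValueError("this target only supports single-turn conversations")
```

With this shape, a message whose first piece is an image still resolves its conversation through the text piece.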